"""
EBA control variables assembly module.

Merges data from WEO, PWT, WDI, KAOPEN, and EWN into a unified
country-year panel with all EBA (External Balance Assessment)
control variables needed for the current account model.

EBA Controls:
- Fiscal balance / GDP (WEO)
- Lagged NFA / GDP (EWN)
- Relative output per worker (PWT)
- Public health expenditure / GDP (WDI)
- Capital openness / KAOPEN (Chinn-Ito)
- Oil/gas trade balance × resource temporariness (WEO)
- Expected GDP growth (WEO)
- Output gap (WEO)
- Life expectancy (WDI)
- Life expectancy × future OADR (WDI × demographics)
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Input directory for downloaded source files and output directory for the
# assembled panel.
# NOTE(review): both are absolute, machine-specific paths — confirm they exist
# in the environment where this module is run.
RAW_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/data/raw")
PROCESSED_DIR = Path("/mnt/c/demographics_capital_flows/multilateral/data/processed")

# ISO2 -> ISO3 mapping for a subset of countries (not exhaustive).
# In the code visible here it is consulted only by load_kaopen(), as a
# fallback when the KAOPEN file has an 'iso2' column but neither 'iso3' nor
# 'ccode'. ISO2 codes missing from this table map to NaN.
ISO2_TO_ISO3 = {
    'US': 'USA', 'GB': 'GBR', 'DE': 'DEU', 'FR': 'FRA', 'JP': 'JPN',
    'CA': 'CAN', 'AU': 'AUS', 'IT': 'ITA', 'ES': 'ESP', 'NL': 'NLD',
    'BE': 'BEL', 'AT': 'AUT', 'CH': 'CHE', 'SE': 'SWE', 'NO': 'NOR',
    'DK': 'DNK', 'FI': 'FIN', 'IE': 'IRL', 'PT': 'PRT', 'GR': 'GRC',
    'NZ': 'NZL', 'KR': 'KOR', 'MX': 'MEX', 'TR': 'TUR', 'PL': 'POL',
    'CZ': 'CZE', 'HU': 'HUN', 'IL': 'ISR', 'CL': 'CHL', 'ZA': 'ZAF',
    'CO': 'COL', 'IS': 'ISL', 'LU': 'LUX', 'SK': 'SVK', 'SI': 'SVN',
    'EE': 'EST', 'LV': 'LVA', 'LT': 'LTU', 'IN': 'IND', 'CN': 'CHN',
    'BR': 'BRA', 'RU': 'RUS', 'ID': 'IDN', 'TH': 'THA', 'MY': 'MYS',
    'PH': 'PHL', 'VN': 'VNM', 'NG': 'NGA', 'KE': 'KEN', 'GH': 'GHA',
    'TZ': 'TZA', 'ET': 'ETH', 'EG': 'EGY', 'MA': 'MAR', 'DZ': 'DZA',
    'TN': 'TUN', 'SA': 'SAU', 'AE': 'ARE', 'PK': 'PAK', 'BD': 'BGD',
    'AR': 'ARG', 'PE': 'PER', 'SG': 'SGP', 'HK': 'HKG', 'TW': 'TWN',
    'RO': 'ROU', 'BG': 'BGR', 'HR': 'HRV', 'RS': 'SRB', 'UA': 'UKR',
}


def load_weo():
    """Read the WEO extract ``weo_data.csv`` from ``RAW_DIR``.

    Returns:
        The file contents as a DataFrame with its ``year`` column cast to
        ``int``, or ``None`` (after printing a hint) when the file is absent.
    """
    weo_csv = RAW_DIR / "weo_data.csv"
    if weo_csv.exists():
        weo = pd.read_csv(weo_csv)
        weo['year'] = weo['year'].astype(int)
        return weo

    print("WEO data not found. Run download.py first.")
    return None


def load_pwt():
    """Read the Penn World Table workbook and keep the columns used downstream.

    The ``Data`` sheet of ``pwt1001.xlsx`` is read, the PWT columns listed
    below that are actually present are kept (in the listed order) and renamed
    to the panel's naming scheme. When both real GDP and employment are
    available, ``output_per_worker`` is added as their ratio.

    Returns:
        The trimmed DataFrame, or ``None`` (after printing a hint) when the
        workbook is absent.
    """
    workbook = RAW_DIR / "pwt1001.xlsx"
    if not workbook.exists():
        print("PWT data not found. Run download.py first.")
        return None

    raw = pd.read_excel(workbook, sheet_name='Data')

    # PWT column name -> panel column name.
    pwt_to_panel = {
        'countrycode': 'iso3',
        'year': 'year',
        'rgdpo': 'rgdp_output',    # real GDP at output-side PPP
        'emp': 'employment',       # persons employed (millions)
        'hc': 'human_capital',     # human capital index
        'pop': 'pop_pwt',          # population (millions)
        'csh_x': 'export_share',   # share of exports in GDP
        'csh_m': 'import_share',   # share of imports in GDP
    }

    # Tolerate workbooks that lack some of the columns.
    present = [source for source in pwt_to_panel if source in raw.columns]
    pwt = raw[present].rename(
        columns={source: pwt_to_panel[source] for source in present}
    )

    if {'rgdp_output', 'employment'}.issubset(pwt.columns):
        pwt['output_per_worker'] = pwt['rgdp_output'] / pwt['employment']

    return pwt


def load_wdi():
    """Read the World Bank WDI extract ``wdi_data.csv`` from ``RAW_DIR``.

    Returns:
        The file contents as a DataFrame with ``year`` cast to ``int``, or
        ``None`` (after printing a hint) when the file is absent.
    """
    wdi_csv = RAW_DIR / "wdi_data.csv"
    if wdi_csv.exists():
        wdi = pd.read_csv(wdi_csv)
        wdi['year'] = wdi['year'].astype(int)
        return wdi

    print("WDI data not found. Run download.py first.")
    return None


def load_kaopen():
    """Read the Chinn-Ito file ``kaopen.csv`` and normalise its key columns.

    Column handling:
    - ``iso3``: if absent, ``ccode`` is renamed to ``iso3``; failing that,
      an ``iso2`` column is translated through ``ISO2_TO_ISO3``. If none of
      these columns exist the frame is returned without ``iso3``.
    - ``kaopen``: if absent, the first column whose lower-cased name contains
      ``kaopen`` or ``ka_open`` is renamed to ``kaopen``.

    Returns:
        The DataFrame, or ``None`` (after printing a hint) when the file is
        absent.
    """
    source = RAW_DIR / "kaopen.csv"
    if not source.exists():
        print("KAOPEN data not found. Run download.py first.")
        return None

    kao = pd.read_csv(source)

    # Per the original author's note the file has columns
    # cn, ccode, country, year, kaopen, kaopen.1 with ISO3 codes in ccode.
    has_iso3 = 'iso3' in kao.columns
    if not has_iso3 and 'ccode' in kao.columns:
        kao = kao.rename(columns={'ccode': 'iso3'})
    elif not has_iso3 and 'iso2' in kao.columns:
        kao['iso3'] = kao['iso2'].map(ISO2_TO_ISO3)

    if 'kaopen' not in kao.columns:
        alias = next(
            (name for name in kao.columns
             if 'kaopen' in name.lower() or 'ka_open' in name.lower()),
            None,
        )
        if alias is not None:
            kao = kao.rename(columns={alias: 'kaopen'})

    # Downstream code reads 'kaopen', not 'kaopen.1' (described by the
    # original author as the 0-1 normalised version).
    return kao


def load_pensions():
    """Read the pension extract ``oecd_pensions.csv`` from ``RAW_DIR``.

    Returns:
        The file contents as a DataFrame with ``year`` cast to ``int``, or
        ``None`` (after printing a hint) when the file is absent.
    """
    pensions_csv = RAW_DIR / "oecd_pensions.csv"
    if pensions_csv.exists():
        pensions = pd.read_csv(pensions_csv)
        pensions['year'] = pensions['year'].astype(int)
        return pensions

    print("Pension data not found. Run download.py first.")
    return None


def load_savings_investment():
    """Read ``wdi_savings_investment.csv`` from ``RAW_DIR``.

    Returns:
        The file contents as a DataFrame with ``year`` cast to ``int``, or
        ``None`` (after printing a hint) when the file is absent.
    """
    sav_inv_csv = RAW_DIR / "wdi_savings_investment.csv"
    if sav_inv_csv.exists():
        sav_inv = pd.read_csv(sav_inv_csv)
        sav_inv['year'] = sav_inv['year'].astype(int)
        return sav_inv

    print("WDI savings/investment data not found. Run download.py first.")
    return None


def load_ewn():
    """Load Lane & Milesi-Ferretti External Wealth of Nations data.

    Reads ``ewn.csv`` from ``RAW_DIR``, renames the EWN columns that are
    present to short snake_case names, derives gross-position ratios
    (numerator / ``gdp_usd_ewn`` * 100) and attaches an ``iso3`` column.

    ISO3 assignment:
    - IFS codes are coerced to numeric (so codes stored as text still match)
      and looked up in ``_build_ifs_to_iso3_map()``.
    - Rows still without a code fall back to a lookup of the ``Country``
      name in ``_country_name_to_iso3()``.
    - Rows matched by neither keep a missing ``iso3``.

    Returns:
        The DataFrame, or ``None`` (after printing a hint) when the file is
        absent.
    """
    fpath = RAW_DIR / "ewn.csv"
    if not fpath.exists():
        print("EWN data not found. Run download.py first.")
        return None

    df = pd.read_csv(fpath)

    # EWN column name -> panel column name. DataFrame.rename ignores keys that
    # are not present, so no per-column existence checks are needed.
    column_names = {
        'Year': 'year',
        'IFS_Code': 'ifs_code',
        'Net IIP excl gold': 'nfa',
        'GDP (US$)': 'gdp_usd_ewn',
        'Total assets': 'total_assets',
        'Total liabilities': 'total_liabilities',
        'net IIP excl gold / GDP domestic currency': 'nfa_gdp',
        # Asset/liability decomposition
        'Portfolio equity assets (stock)': 'port_eq_assets',
        'Portfolio equity liabilities (stock)': 'port_eq_liab',
        'FDI assets (stock)': 'fdi_assets',
        'FDI liabilities (stock)': 'fdi_liab',
        'Debt assets (stock)': 'debt_assets',
        'Debt liabilities (stock)': 'debt_liab',
        'FX Reserves minus gold': 'fx_reserves',
    }
    df = df.rename(columns=column_names)

    # Gross positions as % of GDP: (ratio column, numerator columns to sum).
    # Listed in the order the columns are appended to the frame.
    # Note: no ratio is produced for 'port_eq_liab'.
    ratio_specs = [
        ('gross_assets_gdp', ['total_assets']),
        ('gross_liab_gdp', ['total_liabilities']),
        ('gross_ifi', ['total_assets', 'total_liabilities']),
        ('fdi_assets_gdp', ['fdi_assets']),
        ('fdi_liab_gdp', ['fdi_liab']),
        ('port_eq_assets_gdp', ['port_eq_assets']),
        ('debt_assets_gdp', ['debt_assets']),
        ('debt_liab_gdp', ['debt_liab']),
        ('fx_reserves_gdp', ['fx_reserves']),
    ]
    if 'gdp_usd_ewn' in df.columns:
        # Zero GDP becomes NaN so the ratios are missing rather than +/-inf.
        gdp = df['gdp_usd_ewn'].replace(0, np.nan)
        for ratio_col, parts in ratio_specs:
            if not all(part in df.columns for part in parts):
                continue
            numerator = df[parts[0]]
            for extra in parts[1:]:
                numerator = numerator + df[extra]
            df[ratio_col] = numerator / gdp * 100

    # --- ISO3 assignment -------------------------------------------------
    iso3 = pd.Series(np.nan, index=df.index, dtype=object)

    if 'ifs_code' in df.columns:
        # Coerce so that codes read as strings ("111") or floats (111.0)
        # match the integer keys of the lookup table.
        codes = pd.to_numeric(df['ifs_code'], errors='coerce')
        iso3 = codes.map(_build_ifs_to_iso3_map()).astype(object)

    if 'Country' in df.columns:
        # Fill only the rows the IFS-code lookup could not resolve.
        by_name = df['Country'].map(_country_name_to_iso3())
        iso3 = iso3.where(iso3.notna(), by_name)

    df['iso3'] = iso3

    return df


def _build_ifs_to_iso3_map():
    """Return the IFS numeric country code -> ISO3 lookup table.

    The table has 209 entries. The groups below follow IFS code ranges, so
    the headings are only approximate (e.g. 924 China and 948 Mongolia sit in
    the 9xx range, 816 Faroe Islands and 823 Gibraltar in the 8xx range).

    NOTE(review): the original author describes the table as derived from
    the EWN (Lane & Milesi-Ferretti) country list, excluding the Euro Area
    and ECCU aggregates; that provenance cannot be checked from this file.
    """
    code_ranges = (
        # 1xx — advanced economies (plus Turkey, South Africa)
        {
            111: 'USA', 112: 'GBR', 122: 'AUT', 124: 'BEL', 128: 'DNK',
            132: 'FRA', 134: 'DEU', 135: 'SMR', 136: 'ITA', 137: 'LUX',
            138: 'NLD', 142: 'NOR', 144: 'SWE', 146: 'CHE', 147: 'LIE',
            156: 'CAN', 158: 'JPN', 171: 'AND', 172: 'FIN', 174: 'GRC',
            176: 'ISL', 178: 'IRL', 181: 'MLT', 182: 'PRT', 184: 'ESP',
            186: 'TUR', 193: 'AUS', 196: 'NZL', 199: 'ZAF',
        },
        # 2xx — Latin America
        {
            213: 'ARG', 218: 'BOL', 223: 'BRA', 228: 'CHL', 233: 'COL',
            238: 'CRI', 243: 'DOM', 248: 'ECU', 253: 'SLV', 258: 'GTM',
            263: 'HTI', 268: 'HND', 273: 'MEX', 278: 'NIC', 283: 'PAN',
            288: 'PRY', 293: 'PER', 298: 'URY', 299: 'VEN',
        },
        # 3xx — Caribbean
        {
            311: 'ATG', 312: 'AIA', 313: 'BHS', 314: 'ABW', 316: 'BRB',
            319: 'BMU', 321: 'DMA', 328: 'GRD', 336: 'GUY', 339: 'BLZ',
            343: 'JAM', 351: 'MSR', 352: 'SXM', 353: 'ANT', 354: 'CUW',
            361: 'KNA', 362: 'LCA', 364: 'VCT', 366: 'SUR', 369: 'TTO',
            371: 'VGB', 377: 'CYM', 381: 'TCA',
        },
        # 4xx — Middle East
        {
            419: 'BHR', 423: 'CYP', 429: 'IRN', 433: 'IRQ', 436: 'ISR',
            439: 'JOR', 443: 'KWT', 446: 'LBN', 449: 'OMN', 453: 'QAT',
            456: 'SAU', 463: 'SYR', 466: 'ARE', 469: 'EGY', 474: 'YEM',
            487: 'PSE',
        },
        # 5xx — Asia & Pacific
        {
            512: 'AFG', 513: 'BGD', 514: 'BTN', 516: 'BRN', 518: 'MMR',
            522: 'KHM', 524: 'LKA', 528: 'TWN', 532: 'HKG', 534: 'IND',
            536: 'IDN', 537: 'TLS', 542: 'KOR', 544: 'LAO', 546: 'MAC',
            548: 'MYS', 556: 'MDV', 558: 'NPL', 564: 'PAK', 565: 'PLW',
            566: 'PHL', 576: 'SGP', 578: 'THA', 582: 'VNM',
        },
        # 6xx-7xx — Africa
        {
            611: 'DJI', 612: 'DZA', 614: 'AGO', 616: 'BWA', 618: 'BDI',
            622: 'CMR', 624: 'CPV', 626: 'CAF', 628: 'TCD', 632: 'COM',
            634: 'COG', 636: 'COD', 638: 'BEN', 642: 'GNQ', 643: 'ERI',
            644: 'ETH', 646: 'GAB', 648: 'GMB', 652: 'GHA', 654: 'GNB',
            656: 'GIN', 662: 'CIV', 664: 'KEN', 666: 'LSO', 668: 'LBR',
            672: 'LBY', 674: 'MDG', 676: 'MWI', 678: 'MLI', 682: 'MRT',
            684: 'MUS', 686: 'MAR', 688: 'MOZ', 692: 'NER', 694: 'NGA',
            698: 'ZWE', 714: 'RWA', 716: 'STP', 718: 'SYC', 722: 'SEN',
            724: 'SLE', 726: 'SOM', 728: 'NAM', 732: 'SDN', 733: 'SSD',
            734: 'SWZ', 738: 'TZA', 742: 'TGO', 744: 'TUN', 746: 'UGA',
            748: 'BFA', 754: 'ZMB',
        },
        # 8xx — Pacific islands (plus Faroe Islands, Gibraltar)
        {
            813: 'SLB', 816: 'FRO', 819: 'FJI', 823: 'GIB', 826: 'KIR',
            836: 'NRU', 839: 'NCL', 846: 'VUT', 853: 'PNG', 862: 'WSM',
            866: 'TON', 867: 'MHL', 868: 'FSM', 869: 'TUV', 887: 'PYF',
        },
        # 9xx — CIS & Eastern Europe (plus China, Mongolia)
        {
            911: 'ARM', 912: 'AZE', 913: 'BLR', 914: 'ALB', 915: 'GEO',
            916: 'KAZ', 917: 'KGZ', 918: 'BGR', 921: 'MDA', 922: 'RUS',
            923: 'TJK', 924: 'CHN', 925: 'TKM', 926: 'UKR', 927: 'UZB',
            935: 'CZE', 936: 'SVK', 939: 'EST', 941: 'LVA', 942: 'SRB',
            943: 'MNE', 944: 'HUN', 946: 'LTU', 948: 'MNG', 960: 'HRV',
            961: 'SVN', 962: 'MKD', 963: 'BIH', 964: 'POL', 967: 'XKX',
            968: 'ROU',
        },
    )

    ifs_to_iso3 = {}
    for code_range in code_ranges:
        ifs_to_iso3.update(code_range)
    return ifs_to_iso3


def _country_name_to_iso3():
    """Return a small country-name -> ISO3 lookup used as a fallback.

    Only a handful of large economies are covered; some have more than one
    accepted spelling.
    """
    # (ISO3, accepted spellings) — kept in the same order as the flat table.
    spellings = (
        ('USA', ('United States',)),
        ('GBR', ('United Kingdom',)),
        ('DEU', ('Germany',)),
        ('FRA', ('France',)),
        ('JPN', ('Japan',)),
        ('CAN', ('Canada',)),
        ('AUS', ('Australia',)),
        ('ITA', ('Italy',)),
        ('ESP', ('Spain',)),
        ('NLD', ('Netherlands',)),
        ('CHN', ('China,P.R.: Mainland', 'China, P.R.: Mainland')),
        ('IND', ('India',)),
        ('BRA', ('Brazil',)),
        ('RUS', ('Russia', 'Russian Federation')),
        ('KOR', ('Korea, Republic of', 'Korea')),
        ('MEX', ('Mexico',)),
        ('IDN', ('Indonesia',)),
        ('ZAF', ('South Africa',)),
        ('TUR', ('Turkey',)),
        ('SAU', ('Saudi Arabia',)),
        ('NGA', ('Nigeria',)),
    )
    return {name: iso3 for iso3, names in spellings for name in names}


def compute_relative_output_per_worker(pwt_df):
    """
    Express each country's output per worker relative to the United States.

    Rows with a missing ``iso3``, ``year`` or ``output_per_worker`` are
    dropped. Each remaining row is matched to the USA value of the same year;
    years without a USA observation yield missing ratios.

    Args:
        pwt_df: DataFrame with ``iso3``, ``year`` and ``output_per_worker``
            columns, or ``None``.

    Returns:
        DataFrame with columns ``iso3``, ``year``, ``output_per_worker``,
        ``rel_output_per_worker`` (country / USA) and ``log_rel_opw`` (natural
        log of the ratio, with the ratio floored at 0.001), or ``None`` when
        the input is ``None`` or lacks ``output_per_worker``.
    """
    if pwt_df is None:
        return None
    if 'output_per_worker' not in pwt_df.columns:
        return None

    opw = pwt_df[['iso3', 'year', 'output_per_worker']].dropna()

    # USA benchmark: one value per year, joined back onto every country.
    usa_rows = opw.loc[opw['iso3'] == 'USA', ['year', 'output_per_worker']]
    usa_benchmark = usa_rows.rename(columns={'output_per_worker': 'opw_usa'})
    opw = opw.merge(usa_benchmark, how='left', on='year')

    ratio = opw['output_per_worker'] / opw['opw_usa']
    opw['rel_output_per_worker'] = ratio
    # Floor the ratio before taking logs so non-positive values do not
    # produce -inf / NaN.
    opw['log_rel_opw'] = np.log(ratio.clip(lower=0.001))

    wanted = ['iso3', 'year', 'output_per_worker',
              'rel_output_per_worker', 'log_rel_opw']
    return opw[wanted]


def assemble_macro_panel():
    """
    Assemble the full macro control variable panel.

    Loads WEO, PWT, WDI, KAOPEN, EWN, WDI savings/investment and pension
    data and left-joins each source onto the WEO rows using the
    ``iso3``/``year`` keys. A source whose loader returns ``None`` is skipped.
    If WEO itself is unavailable the result has no rows, because every merge
    is a left join onto the WEO base.

    Derived columns (each added only when its inputs are present):
    - ``savings_investment_gap``: ``gross_savings_gdp - gross_investment_gdp``.
    - ``nfa_gdp_lag``: previous row's ``nfa_gdp`` within each country.
    - ``expected_growth``: mean of ``rgdp_growth`` over the next five rows
      (t+1 .. t+5) within each country, using however many of those rows
      exist; missing when there is no later row.
    - ``trade_openness``: computed from PWT export/import shares when the
      merged sources did not already supply it.

    The row-based lag and lead assume one row per country-year with no gaps
    in ``year``.

    Side effects:
        Prints progress and per-column coverage, and writes the panel to
        ``PROCESSED_DIR / "macro_panel.csv"`` (creating the directory if
        needed).

    Returns:
        The assembled panel as a DataFrame.
    """
    print("Assembling macro control variable panel...")

    # Load all sources
    weo = load_weo()
    pwt = load_pwt()
    wdi = load_wdi()
    kaopen = load_kaopen()
    ewn = load_ewn()

    # Start with WEO as the base (widest country-year coverage)
    if weo is not None:
        panel = weo.copy()
        print(f"  WEO base: {panel.shape}")
    else:
        # No base rows to left-join onto: the panel keeps its columns but
        # stays empty. Make that visible instead of failing silently.
        print("  WARNING: WEO base unavailable; the assembled panel will have no rows.")
        panel = pd.DataFrame(columns=['iso3', 'year'])

    # Merge PWT
    if pwt is not None:
        rel_opw = compute_relative_output_per_worker(pwt)
        if rel_opw is not None:
            panel = panel.merge(rel_opw, on=['iso3', 'year'], how='left')
            print(f"  After PWT merge: {panel.shape}")

        # Also get human capital from PWT
        if 'human_capital' in pwt.columns:
            panel = panel.merge(
                pwt[['iso3', 'year', 'human_capital']].dropna(),
                on=['iso3', 'year'], how='left'
            )

    # Merge WDI
    if wdi is not None:
        panel = panel.merge(wdi, on=['iso3', 'year'], how='left')
        print(f"  After WDI merge: {panel.shape}")

    # Merge KAOPEN
    if kaopen is not None and 'kaopen' in kaopen.columns:
        if {'iso3', 'year'}.issubset(kaopen.columns):
            kaopen_clean = kaopen[['iso3', 'year', 'kaopen']].dropna(subset=['iso3'])
            panel = panel.merge(kaopen_clean, on=['iso3', 'year'], how='left')
            print(f"  After KAOPEN merge: {panel.shape}")
        else:
            # load_kaopen() could not derive the merge keys; skip rather
            # than raise a KeyError on the column selection.
            print("  Skipping KAOPEN merge: 'iso3'/'year' columns not available.")

    # Merge EWN (NFA/GDP + gross position decomposition)
    if ewn is not None:
        ewn_wanted = ['nfa_gdp', 'nfa', 'gross_assets_gdp', 'gross_liab_gdp', 'gross_ifi',
                      'fdi_assets_gdp', 'fdi_liab_gdp', 'port_eq_assets_gdp',
                      'debt_assets_gdp', 'debt_liab_gdp', 'fx_reserves_gdp']
        ewn_cols = ['iso3', 'year'] + [c for c in ewn_wanted if c in ewn.columns]
        ewn_clean = ewn[ewn_cols].dropna(subset=['iso3'])
        panel = panel.merge(ewn_clean, on=['iso3', 'year'], how='left')
        print(f"  After EWN merge: {panel.shape}")

    # Merge savings/investment
    sav_inv = load_savings_investment()
    if sav_inv is not None:
        panel = panel.merge(sav_inv, on=['iso3', 'year'], how='left')
        # Compute savings-investment gap (should ≈ CA)
        if 'gross_savings_gdp' in panel.columns and 'gross_investment_gdp' in panel.columns:
            panel['savings_investment_gap'] = panel['gross_savings_gdp'] - panel['gross_investment_gdp']
        print(f"  After savings/investment merge: {panel.shape}")

    # Merge pensions
    pensions = load_pensions()
    if pensions is not None:
        pension_cols = ['iso3', 'year'] + [
            pc for pc in ['pension_spending_gdp', 'pension_coverage']
            if pc in pensions.columns
        ]
        pension_clean = pensions[pension_cols].dropna(subset=['iso3'])
        panel = panel.merge(pension_clean, on=['iso3', 'year'], how='left')
        print(f"  After pension merge: {panel.shape}")

    # --- Construct derived EBA variables ---

    # Lagged NFA/GDP (previous row within each country after sorting by year)
    if 'nfa_gdp' in panel.columns:
        panel = panel.sort_values(['iso3', 'year'])
        panel['nfa_gdp_lag'] = panel.groupby('iso3')['nfa_gdp'].shift(1)

    # Expected GDP growth: forward-looking average of rgdp_growth over the
    # next five rows (t+1 .. t+5) of each country.
    if 'rgdp_growth' in panel.columns:
        panel = panel.sort_values(['iso3', 'year'])

        def _mean_of_next_five(growth):
            # rolling() windows are trailing, so roll over the reversed
            # series to average rows t..t+4, restore the original order,
            # then shift by one row so the window becomes t+1..t+5.
            ahead = growth.iloc[::-1].rolling(5, min_periods=1).mean().iloc[::-1]
            return ahead.shift(-1)

        panel['expected_growth'] = (
            panel.groupby('iso3')['rgdp_growth'].transform(_mean_of_next_five)
        )

    # Trade openness (from WDI or computed from PWT)
    if 'trade_openness' not in panel.columns and pwt is not None:
        if 'export_share' in pwt.columns and 'import_share' in pwt.columns:
            trade = pwt[['iso3', 'year', 'export_share', 'import_share']].copy()
            # abs() so the two shares add up regardless of their sign.
            trade['trade_openness'] = (trade['export_share'].abs() + trade['import_share'].abs()) * 100
            panel = panel.merge(trade[['iso3', 'year', 'trade_openness']], on=['iso3', 'year'], how='left')

    # Save (create the output directory on first run)
    PROCESSED_DIR.mkdir(parents=True, exist_ok=True)
    outfile = PROCESSED_DIR / "macro_panel.csv"
    panel.to_csv(outfile, index=False)
    print(f"  Saved macro panel: {panel.shape} to {outfile}")

    # Summary statistics
    print("\n  Variable coverage:")
    total_rows = len(panel)
    for col in panel.columns:
        if col not in ['iso3', 'year']:
            n = panel[col].notna().sum()
            # Guard the percentage against an empty panel.
            share = n / total_rows * 100 if total_rows else 0.0
            print(f"    {col}: {n:,} non-null ({share:.1f}%)")

    return panel


# ---------------------------------------------------------------------------
# EBA 49 country list
# ---------------------------------------------------------------------------

# ISO3 codes of the 49-country EBA sample. Used by filter_eba_sample() as the
# baseline set of countries to keep.
EBA_COUNTRIES = [
    'USA', 'GBR', 'DEU', 'FRA', 'JPN', 'CAN', 'AUS', 'ITA', 'ESP', 'NLD',
    'BEL', 'AUT', 'CHE', 'SWE', 'NOR', 'DNK', 'FIN', 'IRL', 'PRT', 'GRC',
    'NZL', 'KOR', 'MEX', 'TUR', 'POL', 'CZE', 'HUN', 'ISR', 'CHL', 'ZAF',
    'COL', 'SGP', 'HKG', 'TWN', 'MYS', 'THA', 'PHL', 'IDN', 'IND', 'CHN',
    'BRA', 'ARG', 'RUS', 'SAU', 'ARE', 'EGY', 'PAK', 'MAR', 'PER',
]

# Additional Sub-Saharan African countries for extended sample
# (20 ISO3 codes, appended to EBA_COUNTRIES by filter_eba_sample(extended=True)).
SSA_COUNTRIES = [
    'NGA', 'KEN', 'GHA', 'TZA', 'ETH', 'UGA', 'MOZ', 'SEN', 'CIV',
    'CMR', 'AGO', 'ZMB', 'ZWE', 'BWA', 'MUS', 'NAM', 'RWA', 'MDG',
    'MWI', 'BFA',
]


def filter_eba_sample(panel, extended=True):
    """Keep only the rows of ``panel`` that belong to the EBA sample.

    Args:
        panel: DataFrame with an ``iso3`` column.
        extended: When true (the default), the Sub-Saharan African countries
            in ``SSA_COUNTRIES`` are kept in addition to ``EBA_COUNTRIES``.

    Returns:
        A copy of the matching rows. A one-line summary of the number of
        countries and observations kept is printed.
    """
    if extended:
        countries = EBA_COUNTRIES + SSA_COUNTRIES
    else:
        countries = list(EBA_COUNTRIES)

    in_sample = panel['iso3'].isin(countries)
    filtered = panel.loc[in_sample].copy()

    print(f"  Filtered to {'extended ' if extended else ''}EBA sample: "
          f"{filtered['iso3'].nunique()} countries, {len(filtered):,} obs")
    return filtered


# Script entry point: assemble the macro panel (which also writes it to disk),
# then print its shape, the number of distinct iso3 codes and the span of
# years it covers.
if __name__ == "__main__":
    panel = assemble_macro_panel()
    print(f"\nPanel shape: {panel.shape}")
    print(f"Countries: {panel['iso3'].nunique()}")
    print(f"Years: {panel['year'].min()}-{panel['year'].max()}")
